{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9662657d-0ee6-408d-94fe-da71539dcfd5",
   "metadata": {},
   "source": [
    "# AutoICE - test model and prepare upload package\n",
    "This notebook runs the 'best_model' created in the quickstart notebook on the test scenes, which come without reference data. The model outputs are stored per scene and chart as individual DataArrays in a single xarray Dataset. The Dataset is then compressed and saved as an .nc file, ready to be uploaded to the AI4EO.eu platform. Finally, the inferred charts for each scene are plotted.\n",
    "\n",
    "The first cell imports necessary packages:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1cabf3fe-44c3-42f0-b061-5a70b64379be",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'../dataset/train'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# -- Built-in modules -- #\n",
    "import os\n",
    "\n",
    "# -- Third-party modules -- #\n",
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import xarray as xr\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "# -- Proprietary modules -- #\n",
    "from functions import chart_cbar, r2_metric, f1_metric, compute_metrics\n",
    "from loaders import AI4ArcticChallengeTestDataset\n",
    "from unet import UNet\n",
    "from utils import CHARTS, SIC_LOOKUP, SOD_LOOKUP, FLOE_LOOKUP, SCENE_VARIABLES, colour_str\n",
    "# Restore the `train_options` variable saved earlier with IPython's %store magic.\n",
    "%store -r train_options\n",
    "\n",
    "# Show the processed-data path currently held in the restored options as the cell output.\n",
    "train_options['path_to_processed_data']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b270975-ef64-45c0-bfa0-5f3d48504ab5",
   "metadata": {},
   "source": [
    "### Setup of the GPU resources"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1a3e19a1-76d0-4047-b4a4-8b5f8af19bb9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[0;32mGPU available!\u001b[0m\n",
      "Total number of available devices:  \u001b[0;33m1\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# Pick the compute device: the configured CUDA device when one is available, otherwise the CPU.\n",
    "if not torch.cuda.is_available():\n",
    "    print(colour_str('GPU not available.', 'red'))\n",
    "    device = torch.device('cpu')\n",
    "\n",
    "else:\n",
    "    print(colour_str('GPU available!', 'green'))\n",
    "    n_devices = torch.cuda.device_count()\n",
    "    print('Total number of available devices: ', colour_str(n_devices, 'orange'))\n",
    "    # The GPU index comes from the stored training options.\n",
    "    device = torch.device(f\"cuda:{train_options['gpu_id']}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f8d365e-0a78-43f2-9182-776148c80d31",
   "metadata": {},
   "source": [
    "### Load the model and stored parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "259d1765",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the model state dict from the 'best_model' checkpoint.\n",
    "# map_location remaps the stored tensors onto the active device, so a checkpoint written on a GPU\n",
    "# can also be restored when this notebook falls back to the CPU.\n",
    "weights = torch.load('best_model', map_location=device)['model_state_dict']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7df8fc3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the leading 'module.' (7 characters) from the state-dict keys so they match the bare model.\n",
    "# NOTE(review): presumably the checkpoint was saved from a torch.nn.DataParallel wrapper - confirm.\n",
    "# The prefix is removed only where it is present, so keys without it are kept intact instead of truncated.\n",
    "WRAPPER_PREFIX = 'module.'\n",
    "weights_2 = {\n",
    "    (key[len(WRAPPER_PREFIX):] if key.startswith(WRAPPER_PREFIX) else key): value\n",
    "    for key, value in weights.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7e164255-8528-46b3-9ce6-6bc27c0effae",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading model.\n",
      "Model successfully loaded.\n"
     ]
    }
   ],
   "source": [
    "print('Loading model.')\n",
    "# Build the U-Net from the stored options, restore the trained weights, then move it to the compute device.\n",
    "net = UNet(options=train_options)\n",
    "net.load_state_dict(weights_2)\n",
    "net = net.to(device)\n",
    "print('Model successfully loaded.')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d16af23-a901-4ca6-8d75-7a54d66a1eaa",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Prepare the scene list, dataset and dataloaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "28f92af3-32ab-4752-832a-29b5972ac375",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Setup ready\n"
     ]
    }
   ],
   "source": [
    "# Read the list of raw test scene file names. os.path.join keeps the path valid whether or not\n",
    "# 'path_to_env' ends with a path separator.\n",
    "testset_path = os.path.join(train_options['path_to_env'], 'datalists/testset.json')\n",
    "with open(testset_path) as testset_file:\n",
    "    raw_scene_files = json.load(testset_file)\n",
    "\n",
    "# Build the preprocessed file name of every scene: characters 17-31 and 77-79 of the raw name,\n",
    "# joined by '_' and followed by '_prep.nc'.\n",
    "# NOTE(review): the slice positions assume a fixed raw file-name layout - confirm against testset.json.\n",
    "train_options['test_list'] = [name[17:32] + '_' + name[77:80] + '_prep.nc' for name in raw_scene_files]\n",
    "train_options['path_to_processed_data'] = '../data/ai4arctic/test/'  # The test scenes are read from their own folder.\n",
    "upload_package = xr.Dataset()  # To store model outputs.\n",
    "dataset = AI4ArcticChallengeTestDataset(options=train_options, files=train_options['test_list'], test=True)\n",
    "# batch_size=None disables automatic batching: the loader yields exactly what the dataset returns per scene.\n",
    "asid_loader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=train_options['num_workers_val'], shuffle=False)\n",
    "print('Setup ready')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "11e0c3da",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Write the current options dict to 'filename.pickle' using the newest pickle protocol.\n",
    "with open('filename.pickle', mode='wb') as options_file:\n",
    "    pickle.dump(train_options, options_file, pickle.HIGHEST_PROTOCOL)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "20730e27-e0cc-4c1a-87a8-f2145d32f5dc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "243ad5ca818f4924823c16fbeb27ac1b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/20 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "ename": "KeyError",
     "evalue": "Caught KeyError in DataLoader worker process 0.\nOriginal Traceback (most recent call last):\n  File \"/home/pc2041/miniconda3/envs/env_ai4arctic/lib/python3.8/site-packages/xarray/core/dataset.py\", line 1336, in _construct_dataarray\n    variable = self._variables[name]\nKeyError: 'SIC'\n\nDuring handling of the above exception, another exception occurred:\n\nTraceback (most recent call last):\n  File \"/home/pc2041/miniconda3/envs/env_ai4arctic/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py\", line 302, in _worker_loop\n    data = fetcher.fetch(index)\n  File \"/home/pc2041/miniconda3/envs/env_ai4arctic/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py\", line 51, in fetch\n    data = self.dataset[possibly_batched_index]\n  File \"/media/pc2041/data/vip_lab/ai4arctic_challenge/loaders.py\", line 255, in __getitem__\n    x, y = self.prep_scene(scene)\n  File \"/media/pc2041/data/vip_lab/ai4arctic_challenge/loaders.py\", line 221, in prep_scene\n    size=scene['SIC'].values.shape,\n  File \"/home/pc2041/miniconda3/envs/env_ai4arctic/lib/python3.8/site-packages/xarray/core/dataset.py\", line 1427, in __getitem__\n    return self._construct_dataarray(key)\n  File \"/home/pc2041/miniconda3/envs/env_ai4arctic/lib/python3.8/site-packages/xarray/core/dataset.py\", line 1338, in _construct_dataarray\n    _, name, variable = _get_virtual_variable(self._variables, name, self.dims)\n  File \"/home/pc2041/miniconda3/envs/env_ai4arctic/lib/python3.8/site-packages/xarray/core/dataset.py\", line 177, in _get_virtual_variable\n    raise KeyError(key)\nKeyError: 'SIC'\n",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyError\u001b[0m                                  Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[8], line 4\u001b[0m\n\u001b[1;32m      2\u001b[0m os\u001b[39m.\u001b[39mmakedirs(\u001b[39m'\u001b[39m\u001b[39minference\u001b[39m\u001b[39m'\u001b[39m, exist_ok\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n\u001b[1;32m      3\u001b[0m net\u001b[39m.\u001b[39meval()\n\u001b[0;32m----> 4\u001b[0m \u001b[39mfor\u001b[39;00m inf_x, _, masks, scene_name \u001b[39min\u001b[39;00m tqdm(iterable\u001b[39m=\u001b[39masid_loader, total\u001b[39m=\u001b[39m\u001b[39mlen\u001b[39m(train_options[\u001b[39m'\u001b[39m\u001b[39mtest_list\u001b[39m\u001b[39m'\u001b[39m]), colour\u001b[39m=\u001b[39m\u001b[39m'\u001b[39m\u001b[39mgreen\u001b[39m\u001b[39m'\u001b[39m, position\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m):\n\u001b[1;32m      5\u001b[0m     scene_name \u001b[39m=\u001b[39m scene_name[:\u001b[39m19\u001b[39m]  \u001b[39m# Removes the _prep.nc from the name.\u001b[39;00m\n\u001b[1;32m      6\u001b[0m     torch\u001b[39m.\u001b[39mcuda\u001b[39m.\u001b[39mempty_cache()\n",
      "File \u001b[0;32m~/miniconda3/envs/env_ai4arctic/lib/python3.8/site-packages/tqdm/notebook.py:259\u001b[0m, in \u001b[0;36mtqdm_notebook.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    257\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m    258\u001b[0m     it \u001b[39m=\u001b[39m \u001b[39msuper\u001b[39m(tqdm_notebook, \u001b[39mself\u001b[39m)\u001b[39m.\u001b[39m\u001b[39m__iter__\u001b[39m()\n\u001b[0;32m--> 259\u001b[0m     \u001b[39mfor\u001b[39;00m obj \u001b[39min\u001b[39;00m it:\n\u001b[1;32m    260\u001b[0m         \u001b[39m# return super(tqdm...) will not catch exception\u001b[39;00m\n\u001b[1;32m    261\u001b[0m         \u001b[39myield\u001b[39;00m obj\n\u001b[1;32m    262\u001b[0m \u001b[39m# NB: except ... [ as ...] breaks IPython async KeyboardInterrupt\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/env_ai4arctic/lib/python3.8/site-packages/tqdm/std.py:1195\u001b[0m, in \u001b[0;36mtqdm.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1192\u001b[0m time \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_time\n\u001b[1;32m   1194\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m-> 1195\u001b[0m     \u001b[39mfor\u001b[39;00m obj \u001b[39min\u001b[39;00m iterable:\n\u001b[1;32m   1196\u001b[0m         \u001b[39myield\u001b[39;00m obj\n\u001b[1;32m   1197\u001b[0m         \u001b[39m# Update and possibly print the progressbar.\u001b[39;00m\n\u001b[1;32m   1198\u001b[0m         \u001b[39m# Note: does not call self.update(1) for speed optimisation.\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/env_ai4arctic/lib/python3.8/site-packages/torch/utils/data/dataloader.py:652\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    649\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_sampler_iter \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m    650\u001b[0m     \u001b[39m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[1;32m    651\u001b[0m     \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_reset()  \u001b[39m# type: ignore[call-arg]\u001b[39;00m\n\u001b[0;32m--> 652\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_next_data()\n\u001b[1;32m    653\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\n\u001b[1;32m    654\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dataset_kind \u001b[39m==\u001b[39m _DatasetKind\u001b[39m.\u001b[39mIterable \u001b[39mand\u001b[39;00m \\\n\u001b[1;32m    655\u001b[0m         \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m \\\n\u001b[1;32m    656\u001b[0m         \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m>\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called:\n",
      "File \u001b[0;32m~/miniconda3/envs/env_ai4arctic/lib/python3.8/site-packages/torch/utils/data/dataloader.py:1347\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1345\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m   1346\u001b[0m     \u001b[39mdel\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_task_info[idx]\n\u001b[0;32m-> 1347\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_process_data(data)\n",
      "File \u001b[0;32m~/miniconda3/envs/env_ai4arctic/lib/python3.8/site-packages/torch/utils/data/dataloader.py:1373\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._process_data\u001b[0;34m(self, data)\u001b[0m\n\u001b[1;32m   1371\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_try_put_index()\n\u001b[1;32m   1372\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(data, ExceptionWrapper):\n\u001b[0;32m-> 1373\u001b[0m     data\u001b[39m.\u001b[39;49mreraise()\n\u001b[1;32m   1374\u001b[0m \u001b[39mreturn\u001b[39;00m data\n",
      "File \u001b[0;32m~/miniconda3/envs/env_ai4arctic/lib/python3.8/site-packages/torch/_utils.py:461\u001b[0m, in \u001b[0;36mExceptionWrapper.reraise\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    457\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mTypeError\u001b[39;00m:\n\u001b[1;32m    458\u001b[0m     \u001b[39m# If the exception takes multiple arguments, don't try to\u001b[39;00m\n\u001b[1;32m    459\u001b[0m     \u001b[39m# instantiate since we don't know how to\u001b[39;00m\n\u001b[1;32m    460\u001b[0m     \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(msg) \u001b[39mfrom\u001b[39;00m \u001b[39mNone\u001b[39m\n\u001b[0;32m--> 461\u001b[0m \u001b[39mraise\u001b[39;00m exception\n",
      "\u001b[0;31mKeyError\u001b[0m: Caught KeyError in DataLoader worker process 0.\nOriginal Traceback (most recent call last):\n  File \"/home/pc2041/miniconda3/envs/env_ai4arctic/lib/python3.8/site-packages/xarray/core/dataset.py\", line 1336, in _construct_dataarray\n    variable = self._variables[name]\nKeyError: 'SIC'\n\nDuring handling of the above exception, another exception occurred:\n\nTraceback (most recent call last):\n  File \"/home/pc2041/miniconda3/envs/env_ai4arctic/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py\", line 302, in _worker_loop\n    data = fetcher.fetch(index)\n  File \"/home/pc2041/miniconda3/envs/env_ai4arctic/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py\", line 51, in fetch\n    data = self.dataset[possibly_batched_index]\n  File \"/media/pc2041/data/vip_lab/ai4arctic_challenge/loaders.py\", line 255, in __getitem__\n    x, y = self.prep_scene(scene)\n  File \"/media/pc2041/data/vip_lab/ai4arctic_challenge/loaders.py\", line 221, in prep_scene\n    size=scene['SIC'].values.shape,\n  File \"/home/pc2041/miniconda3/envs/env_ai4arctic/lib/python3.8/site-packages/xarray/core/dataset.py\", line 1427, in __getitem__\n    return self._construct_dataarray(key)\n  File \"/home/pc2041/miniconda3/envs/env_ai4arctic/lib/python3.8/site-packages/xarray/core/dataset.py\", line 1338, in _construct_dataarray\n    _, name, variable = _get_virtual_variable(self._variables, name, self.dims)\n  File \"/home/pc2041/miniconda3/envs/env_ai4arctic/lib/python3.8/site-packages/xarray/core/dataset.py\", line 177, in _get_virtual_variable\n    raise KeyError(key)\nKeyError: 'SIC'\n"
     ]
    }
   ],
   "source": [
    "# Run the model on every test scene, store the per-chart predictions in `upload_package`,\n",
    "# save one figure per scene under 'inference/' and finally write the package to 'upload_package.nc'.\n",
    "print('Testing.')\n",
    "os.makedirs('inference', exist_ok=True)\n",
    "net.eval()\n",
    "charts = train_options['charts']\n",
    "# Autocast is only requested when running on CUDA; on the CPU fallback it is disabled explicitly.\n",
    "use_amp = device.type == 'cuda'\n",
    "# NOTE(review): the stored output of this cell is a KeyError: 'SIC' raised in loaders.prep_scene.\n",
    "# The test scenes have no reference charts, so confirm the dataset's test path does not read them.\n",
    "for inf_x, _, masks, scene_name in tqdm(iterable=asid_loader, total=len(train_options['test_list']), colour='green', position=0):\n",
    "    scene_name = scene_name[:19]  # Removes the _prep.nc from the name.\n",
    "    torch.cuda.empty_cache()\n",
    "    inf_x = inf_x.to(device, non_blocking=True)\n",
    "\n",
    "    with torch.no_grad(), torch.cuda.amp.autocast(enabled=use_amp):\n",
    "        output = net(inf_x)\n",
    "\n",
    "    # - Take the argmax over dim 1 of each chart output and store it as uint8 in the upload package.\n",
    "    predictions = {}\n",
    "    for chart in charts:\n",
    "        predictions[chart] = torch.argmax(output[chart], dim=1).squeeze().cpu().numpy()\n",
    "        upload_package[f\"{scene_name}_{chart}\"] = xr.DataArray(name=f\"{scene_name}_{chart}\", data=predictions[chart].astype('uint8'),\n",
    "                                                               dims=(f\"{scene_name}_{chart}_dim0\", f\"{scene_name}_{chart}_dim1\"))\n",
    "\n",
    "    # - Show the scene inference.\n",
    "    # One column per configured chart; squeeze=False keeps `axs` two-dimensional even for a single chart.\n",
    "    fig, axs = plt.subplots(nrows=1, ncols=len(charts), figsize=(10, 10), squeeze=False)\n",
    "    for ax, chart in zip(axs[0], charts):\n",
    "        # Float copy so that masked pixels can be set to NaN and left blank by imshow.\n",
    "        chart_image = predictions[chart].astype(float)\n",
    "        chart_image[masks] = np.nan\n",
    "        ax.imshow(chart_image, vmin=0, vmax=train_options['n_classes'][chart] - 2, cmap='jet', interpolation='nearest')\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "        chart_cbar(ax=ax, n_classes=train_options['n_classes'][chart], chart=chart, cmap='jet')\n",
    "\n",
    "    fig.suptitle(f\"Scene: {scene_name}\", y=0.65)\n",
    "    fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0.5, hspace=-0)\n",
    "    fig.savefig(f\"inference/{scene_name}.png\", format='png', dpi=128, bbox_inches=\"tight\")\n",
    "    # Close only the figure created in this iteration to release its memory.\n",
    "    plt.close(fig)\n",
    "\n",
    "\n",
    "# - Save upload_package with zlib compression.\n",
    "print('Saving upload_package. Compressing data with zlib.')\n",
    "compression = dict(zlib=True, complevel=1)\n",
    "encoding = {var: compression for var in upload_package.data_vars}\n",
    "upload_package.to_netcdf('upload_package.nc', mode='w', format='netcdf4', engine='netcdf4', encoding=encoding)\n",
    "print('Testing completed.')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.15 ('env_ai4arctic')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.15"
  },
  "vscode": {
   "interpreter": {
    "hash": "a55a9ce189c6ae0384f136447c0500d14026df24b5a84c83cb225c5750d9692c"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
